In [53]:
# Imports and plotting setup for comparing posterior marginals from Ajax VI,
# a Laplace approximation and Blackjax MCMC.
# NOTE(review): results are read with pickle.load, which can execute arbitrary
# code — only load trusted, locally produced result files.
import pickle
import jax

import matplotlib.pyplot as plt  # NOTE(review): not used by any of the visible cells
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from scipy.stats import gaussian_kde
import plotly.express as px
import pandas as pd
import pickle  # NOTE(review): duplicate of the first import; harmless
tfd = tfp.distributions  # short alias; tfd.Normal builds the Gaussian marginals later on
import plotly
plotly.offline.init_notebook_mode()  # set up plotly's offline rendering inside the notebook
In [54]:
# Shared evaluation grid for every marginal density: 10,000 evenly spaced points on [0, 6].
x = jnp.linspace(start=0, stop=6, num=10000)
In [55]:
# Load the saved Ajax variational-inference object (trusted local file; pickle can run code).
with open('./results_data/linear_regression_Ajax', 'rb') as vi_file:
    variational = pickle.load(vi_file)
In [56]:
# Extract the fitted variational posterior over `theta`.
# NOTE(review): unpacking relies on the pytree leaf order of the object returned by
# `transform_dist` being (loc, scale factor) — confirm against Ajax's implementation.
params = variational.get_params()
# `jax.tree_leaves` is a deprecated alias (removed in recent JAX); use the tree_util form.
loc_m, scale = jax.tree_util.tree_leaves(variational.transform_dist(params['theta']))
# The second leaf is presumably a lower-triangular factor L; L @ L.T is the covariance.
# The name `scale` is kept (it now holds the covariance) because later cells read it.
scale = jnp.dot(scale, scale.T)
In [57]:
# Display the VI mean vector and covariance matrix.
(loc_m, scale)
Out[57]:
(DeviceArray([3.7587612, 2.1869001], dtype=float32),
 DeviceArray([[ 0.13721034, -0.04376228],
              [-0.04376228,  0.03969986]], dtype=float32))
In [58]:
# Start the list of density curves with the Ajax VI marginals: each parameter is
# treated as Normal(loc_m[k], sqrt(scale[k, k])) and evaluated on the grid x.
vi_std = jnp.sqrt(jnp.diag(scale))
all_pdf = [
    tfd.Normal(loc=loc_m[k], scale=vi_std[k]).prob(x)
    for k in range(2)
]
In [59]:
# Load the saved Laplace-approximation results (trusted local file; pickle can run code).
with open('./results_data/linear_regression_laplace', 'rb') as laplace_file:
    laplace = pickle.load(laplace_file)
In [60]:
# Laplace results: take the stored 'mean' and the per-parameter standard deviations
# from the diagonal of the stored 'cov' matrix.
# NOTE: this rebinds `loc_m`, which earlier held the VI mean — re-running the VI
# cells after this one would mix Laplace means with the VI covariance.
laplace_cov = laplace['cov']
std = jnp.sqrt(jnp.diag(laplace_cov))
loc_m = laplace['mean']
In [61]:
# Append the Laplace-approximation marginals, Normal(loc_m[k], std[k]), evaluated on x.
all_pdf.extend(
    tfd.Normal(loc=loc_m[k], scale=std[k]).prob(x) for k in range(2)
)
In [62]:
# Load the saved Blackjax MCMC samples (trusted local file; pickle can run code).
with open('./results_data/MCMC_Blackjax', 'rb') as mcmc_file:
    black_samples = pickle.load(mcmc_file)
In [63]:
# Append a Gaussian-KDE estimate of each parameter's MCMC marginal, evaluated on x.
# bw_method=0.3 is passed to scipy as a fixed scalar bandwidth factor.
# NOTE(review): theta samples presumably have shape (num_samples, 2) — confirm.
theta_samples = black_samples.position['theta']
for k in range(2):
    all_pdf.append(gaussian_kde(theta_samples[:, k], bw_method=0.3)(x))
In [64]:
# One label per density value, in curve-major order (all grid points of a curve together).
curve_names = ['Ajax VI theta0', 'Ajax VI theta1', 'Laplace theta0', 'Laplace theta1', 'MCMC theta0', 'MCMC theta1']
all_label = [name for name in curve_names for _ in range(x.shape[0])]
In [65]:
# Stack the per-curve density arrays into a (n_curves, n_grid) array and flatten it
# row by row into a single long vector.
all_pdf = jnp.array(all_pdf).ravel()
In [66]:
# Build a long-format table (one row per curve × grid point) and plot all marginals.
# The tile count is derived from the data rather than hard-coded, so adding or
# removing a curve cannot silently misalign theta values with densities.
n_curves = all_pdf.shape[0] // x.shape[0]
x_repeated = jnp.tile(x, n_curves)
assert x_repeated.shape[0] == all_pdf.shape[0] == len(all_label), (
    "theta grid, densities and labels must have the same length"
)
to_df = {
    "theta": x_repeated,
    "PDF": all_pdf,
    "label": all_label,
}
df = pd.DataFrame(to_df)

# NOTE(review): the inputs come from `linear_regression_*` result files, but the title
# and output filename say "logistic regression" — confirm which model this is.
fig = px.line(df, x="theta", y="PDF", color="label", title="logistic regression")
fig.show()
fig.write_html("logistic_reg_result_plotly.html")
In [ ]: